import torch
y = torch.tensor([  # Create communities.
    1, 1, 1, 1, 3, 3, 3, 1, 0, 1, 3, 1, 1, 1, 0, 0, 3, 1, 0, 1, 0, 1,
    0, 0, 2, 2, 0, 0, 2, 0, 0, 2, 0, 0
])
train_mask = torch.zeros(y.size(0), dtype=torch.bool)
print(train_mask)

print((y == 2).nonzero(as_tuple=False))

for i in range(int(y.max()) + 1):
    train_mask[(y == i).nonzero(as_tuple=False)[0]] = True

print(train_mask)
# 找到第一个分类是该类型的索引的位置，给True剩下的给 False